package tcode

import (
	"context"
	"crypto/rand"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"math/big"
	"regexp"
	"strconv"
	"time"
)

// init configures the standard logger so that every entry is prefixed with
// the date, the time and the short file name plus line number of the caller.
func init() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)
}

// Constants used by the table code generator.
const (
	// defaultAbsDir is the relative path "./".
	// NOTE(review): presumably the default output directory for generated
	// files; its use is not visible in this part of the file.
	defaultAbsDir    = "./"
	// apostrophe is a single backquote character. It is spliced into
	// codeTemplateText to delimit the generated struct tags, because a
	// backquote cannot appear inside a raw string literal.
	apostrophe       = "`"
	// singleQ is a single-quote character.
	singleQ          = "'"
	// codeTemplateText is the template from which one Go source file per table
	// is rendered. It reads .PackageName, .ImportPackages, .Name, .DbName,
	// .TableName, .Comment and .Fields (each field providing Name, Type,
	// JsonName, ColumnName and Comment) and emits:
	//   - a TableName<Name> constant holding "<DbName>.<TableName>",
	//   - the <Name> struct with one json-tagged field per column,
	//   - the methods TableName, RawColumnContainer, NewTable, NewInstance,
	//     CopyFrom and Columns,
	//   - a ColumnNames<Name> slice listing every column name.
	// NOTE(review): the {{...}} syntax looks like text/template; the code that
	// parses and executes the template is outside this part of the file.
	codeTemplateText = `
package {{.PackageName}}

import (
	"gitee.com/yan-shi-kun/tcode"
	{{range $index, $value := .ImportPackages}}
        "{{$value}}" {{end}}
)
const TableName{{.Name}} = "{{.DbName}}.{{.TableName}}"

// {{.Name}} {{.Comment}}
type {{.Name}} struct {
{{range $index, $value := .Fields}}
    {{$value.Name}}     {{$value.Type}} ` + apostrophe + `json:"{{$value.JsonName}}"` + apostrophe + ` // {{$value.Name}} {{$value.Comment}}{{end}}
}
func (receiver *{{.Name}}) TableName() string {
	return TableName{{.Name}}
}
func (receiver *{{.Name}}) RawColumnContainer(columns ...string) []any {
	container := make([]any, len(columns))
	for i := range columns {
		switch columns[i] { {{range $index, $value := .Fields}}
		case "{{$value.ColumnName}}":
			container[i] = &receiver.{{$value.Name}}{{end}}
		default :
			container[i] = new(any)
			tcode.FuncLog("warn: not matched to column")
		}
	}
	return container
}
func (receiver *{{.Name}}) NewTable() tcode.Table {
	return &{{.Name}}{ {{range $index, $value := .Fields}}
	{{$value.Name}}: receiver.{{$value.Name}},{{end}}
	}
}
func (receiver *{{.Name}}) NewInstance() tcode.Table {
	return &{{.Name}}{}
}
func (receiver *{{.Name}}) CopyFrom(src tcode.Table) {
	if source, ok := src.(*{{.Name}}); ok { {{range $index, $value := .Fields}}
		receiver.{{$value.Name}} = source.{{$value.Name}}{{end}}
	}
}

var ColumnNames{{.Name}} = []string{ {{range $index, $value := .Fields}} "{{$value.ColumnName}}",{{end}} }

func (receiver *{{.Name}}) Columns() []string {
	return ColumnNames{{.Name}}
}
`
)

// Package-level state and overridable hooks.
var (
	// dataTypeMap holds the custom type conversion functions registered via
	// AddDataTypeFunc, keyed by data type name.
	dataTypeMap    = make(map[string]dataTypeFun)
	// listenField is the optional callback installed by ListenField.
	listenField    func(filedInfo *fieldInfo)
	// findNumReg matches the first run of decimal digits; reTryScan uses it to
	// pull a number out of a scan error message.
	findNumReg     = regexp.MustCompile(`\d+`)
	// findInReg matches whitespace, a case-insensitive "in", an optional "(",
	// a "?" placeholder and one or more ")".
	// NOTE(review): presumably used to expand "IN (?)" clauses; the call site
	// is not visible in this part of the file.
	findInReg      = regexp.MustCompile(`\s+(?i)in\s*\(?\s*\?\s*\)+`)
	// FuncLog is the logging hook used throughout the package; it defaults to
	// defaultLog and is exported, so it can be reassigned.
	FuncLog        = defaultLog
	// FuncSQLLog is the SQL logging hook; it defaults to defaultSqlLog.
	FuncSQLLog     = defaultSqlLog
	// FuncGenId is the id generation hook; it defaults to genId.
	FuncGenId      = genId
	// txConnKey is the context key under which WithTX stores a *sql.Tx. The
	// suffix produced by genId is computed once at package initialisation.
	txConnKey      = confString("txConnKey" + genId())
	// txOptionsKey is the context key under which WithTxOptions stores a
	// *sql.TxOptions.
	txOptionsKey   = confString("txOptionsKey" + genId())
	// confKey is the context key under which WithConf stores the key of the
	// configuration to use.
	confKey        = confString("confKey" + genId())
	// defaultConfKey is the key of the configuration most recently registered
	// by New; getContextConf falls back to it when the context carries none.
	defaultConfKey confString
)

// WithConf returns a context derived from ctx that carries conf under the
// package's configuration key, which is the value getContextConf looks up.
func WithConf(ctx context.Context, conf confString) context.Context {
	derived := context.WithValue(ctx, confKey, conf)
	return derived
}
// WithTX returns a context derived from ctx that carries the transaction tx
// under the package's transaction key; GetContextTxConn reads it back.
func WithTX(ctx context.Context, tx *sql.Tx) context.Context {
	derived := context.WithValue(ctx, txConnKey, tx)
	return derived
}
// WithTxOptions returns a context derived from ctx that carries the
// transaction options txo under the package's transaction-options key;
// GetContextTxOptions reads them back.
func WithTxOptions(ctx context.Context, txo *sql.TxOptions) context.Context {
	derived := context.WithValue(ctx, txOptionsKey, txo)
	return derived
}

// GetContextDBConn returns the database handle of the configuration selected
// by ctx (see getContextConf).
//
// It returns nil when no configuration can be resolved, for example when New
// has not registered one yet, instead of panicking on a nil configuration.
// This matches GetContextTxConn, which also reports absence with nil.
func GetContextDBConn(ctx context.Context) *sql.DB {
	conf := getContextConf(ctx)
	if conf == nil {
		return nil
	}
	return conf.DB
}

// GetContextTxConn returns the transaction stored in ctx by WithTX.
// It returns nil when ctx carries no value under the transaction key or when
// the stored value is not a *sql.Tx.
func GetContextTxConn(ctx context.Context) *sql.Tx {
	// The comma-ok assertion yields a nil *sql.Tx for both a missing value and
	// a value of another type.
	tx, _ := ctx.Value(txConnKey).(*sql.Tx)
	return tx
}

// GetContextTxOptions returns the transaction options to use for ctx.
//
// Options stored with WithTxOptions take precedence and are returned as they
// were stored. Otherwise the DefaultTxOptions of the configuration selected by
// ctx are returned, or nil when no configuration can be resolved.
//
// Previously a context without options returned nil straight away, so the
// fallback to DefaultTxOptions could only be reached by a stored value of the
// wrong type and the configured default was effectively never applied.
func GetContextTxOptions(ctx context.Context) *sql.TxOptions {
	if opts, ok := ctx.Value(txOptionsKey).(*sql.TxOptions); ok {
		return opts
	}
	// getContextConf may return nil when no configuration is registered.
	if conf := getContextConf(ctx); conf != nil {
		return conf.DefaultTxOptions
	}
	return nil
}

// getContextConf resolves the configuration that applies to ctx.
// When ctx carries no configuration key, the configuration stored under
// defaultConfKey is returned. When it carries a confString, the configuration
// stored under that key is returned. A value of any other type yields nil.
// The result is also nil when confMap has no entry for the key.
func getContextConf(ctx context.Context) *Config {
	switch key := ctx.Value(confKey).(type) {
	case nil:
		return confMap[defaultConfKey]
	case confString:
		return confMap[key]
	default:
		return nil
	}
}

// New registers cfg and returns the key it was registered under together with
// the code factory for cfg.Dialect.
//
// Missing defaults are filled in by refineDefaultConf, and a database handle
// is opened through initDBConn when cfg.DB is nil. The configuration is stored
// in confMap under the returned key, which also becomes the package default
// (defaultConfKey).
//
// An error is returned when cfg is nil or when the database connection cannot
// be initialised; in the latter case the generated key is still returned but
// nothing is registered. An unsupported dialect is only logged: the returned
// factory is then nil while err is nil.
func New(cfg *Config) (actuators confString, codeConstructor CodeFactory, err error) {
	if cfg == nil {
		err = errors.New("cfg is nil")
		FuncLog(err)
		return actuators, nil, err
	}

	refineDefaultConf(cfg)
	actuators = confString("tcode" + genId())

	// Open a connection only when the caller did not supply one.
	if cfg.DB == nil {
		if err = initDBConn(cfg); err != nil {
			return actuators, nil, err
		}
	}

	switch dialect := cfg.Dialect; dialect {
	case "mysql":
		codeConstructor = new(mysqlDbParse)
	case "oracle":
		codeConstructor = new(oracleDbParse)
	default:
		FuncLog("err:not find", dialect, "implement")
	}

	confMap[actuators] = cfg
	defaultConfKey = actuators
	return actuators, codeConstructor, nil
}

// refineDefaultConf fills in defaults for settings the caller left empty:
// the primary key column name becomes "id" and the package name "tables".
func refineDefaultConf(cfg *Config) {
	if cfg.PrimaryKeyColumnName == "" {
		cfg.PrimaryKeyColumnName = "id"
	}
	if cfg.PackageName == "" {
		cfg.PackageName = "tables"
	}
}

// initDBConn opens the database described by cfg.DriverName and cfg.Dsn,
// applies the pool settings from cfg and verifies the connection with Ping.
//
// On success the handle is stored in cfg.DB. On any failure the error is
// logged through FuncLog and returned, and cfg.DB is left nil. A handle that
// fails Ping is closed first, so it is neither leaked nor left on cfg where a
// later New call would mistake it for a working connection.
func initDBConn(cfg *Config) error {
	cfg.DB = nil

	db, err := sql.Open(cfg.DriverName, cfg.Dsn)
	if err != nil {
		FuncLog(err)
		return err
	}

	db.SetConnMaxLifetime(time.Second * time.Duration(cfg.ConnMaxLifetimeSecond))
	db.SetMaxOpenConns(cfg.MaxOpenConns)
	db.SetMaxIdleConns(cfg.MaxIdleConns)

	if err = db.Ping(); err != nil {
		FuncLog(err)
		// The Ping error is the one worth reporting; a failure to close an
		// unusable handle adds nothing for the caller.
		_ = db.Close()
		return err
	}

	cfg.DB = db
	return nil
}

// reTryScan scans the current row of rows into container, retrying when a
// column cannot be scanned into its destination.
//
// Per the original note it is comparatively slow and meant for result sets
// whose columns may contain NULL (nil) values; it is not used by default and
// is selected when sqlInfo.possibleQueryEmptyColumn is true (that switch lives
// outside this part of the file).
//
// After every failed Scan the destination of the offending column is
// temporarily replaced with new(any) and the scan is attempted again. The
// original destinations are put back into container before the function
// returns.
//
// It returns nil when a Scan succeeds or when the loop runs out of columns,
// and otherwise the Scan error (or the strconv error when no column index can
// be parsed from the message).
// NOTE(review): when the failing column is the last one, the loop ends and nil
// is returned without another Scan; this appears to rely on Scan having
// already filled the preceding destinations - confirm.
func reTryScan(container []any, rows *sql.Rows) error {
	// Work around scan errors caused by empty (nil) data.
	// When a column fails, its original pointer is saved, new(any) is used in its
	// place, and the original is restored into container before returning.
	errContainer := make(map[int]any)
	defer func() {
		for i := range errContainer {
			container[i] = errContainer[i]
		}
	}()
	ci := len(container)
	var err error
	var preErrIndex = -1
	for i := 0; i < ci; i++ {
		err = rows.Scan(container...)
		if err == nil { // scan succeeded, done
			return nil
		}
		if err == sql.ErrNoRows {
			FuncLog(err)
			return err
		}
		// Take the first run of digits in the error text as the failing column.
		// NOTE(review): presumably relies on database/sql's "Scan error on column
		// index N" wording - confirm.
		errIndexStr := findNumReg.FindString(err.Error())
		errIndex, atoiErr := strconv.Atoi(errIndexStr)
		if atoiErr != nil {
			FuncLog(atoiErr)
			return atoiErr
		}
		// The same column failed twice in a row: substituting the destination did
		// not help, so give up.
		if preErrIndex == errIndex {
			FuncLog(err)
			return err
		}
		if i <= errIndex && errIndex < ci {
			preErrIndex = errIndex // remember the last failing index to avoid an endless loop
			i = errIndex           // continue from the failing column index
		} else {
			FuncLog(err)
			return err
		}
		// keep the original pointer
		errContainer[i] = container[i]
		// swap in a destination that accepts any value
		container[i] = new(any)
	}
	return nil
}

// AddDataTypeFunc registers fun as the conversion function for dataType in
// dataTypeMap, replacing any earlier registration for the same name. Per the
// original note, registered functions take precedence over the types listed in
// columnTypeParse. Calls with an empty dataType or a nil fun are ignored.
func AddDataTypeFunc(dataType string, fun dataTypeFun) {
	if dataType != "" && fun != nil {
		dataTypeMap[dataType] = fun
	}
}

// ListenField installs lf as the package-wide field listener, replacing any
// listener set earlier. Per the original note, the listener is meant to see
// every processed field and may modify its information in place; the code that
// invokes it is outside this part of the file.
func ListenField(lf func(field *fieldInfo)) {
	listenField = lf
}

// StringInSlice reports whether s is an element of slice.
func StringInSlice(s string, slice []string) bool {
	for _, candidate := range slice {
		if candidate == s {
			return true
		}
	}
	return false
}

// StringInIndex returns the index of the first element of slice equal to s,
// or -1 when there is none.
func StringInIndex(s string, slice []string) int {
	for idx, candidate := range slice {
		if candidate == s {
			return idx
		}
	}
	return -1
}

// Hump converts a column name into a camel-case field name.
//
// The first byte is upper-cased when it is an ASCII lowercase letter. Every
// '_' or '-' after it is dropped and the byte that follows is upper-cased when
// it is an ASCII lowercase letter (any other byte, including another
// separator, is copied unchanged). All remaining ASCII uppercase letters are
// lower-cased. A trailing separator is simply removed. Bytes outside the ASCII
// letter ranges are copied as they are.
//
// Examples: "user_name" -> "UserName", "USER_NAME" -> "UserName",
// "user-id" -> "UserId", "userName" -> "Username".
func Hump(column string) (filed string) {
	n := len(column)
	if n == 0 {
		return ""
	}
	const caseGap = 'a' - 'A'
	out := make([]byte, 0, n)

	head := column[0]
	if 'a' <= head && head <= 'z' {
		head -= caseGap
	}
	out = append(out, head)

	for pos := 1; pos < n; pos++ {
		c := column[pos]
		if c == '_' || c == '-' {
			// Skip the separator and capitalise whatever follows it.
			pos++
			if pos >= n {
				break
			}
			c = column[pos]
			if 'a' <= c && c <= 'z' {
				c -= caseGap
			}
		} else if 'A' <= c && c <= 'Z' {
			c += caseGap
		}
		out = append(out, c)
	}
	return string(out)
}

// PrefixLower returns s with its first byte lower-cased when that byte is an
// ASCII uppercase letter; any other input is returned unchanged.
func PrefixLower(s string) string {
	if s == "" {
		return s
	}
	first := s[0]
	if first < 'A' || first > 'Z' {
		return s
	}
	buf := []byte(s)
	buf[0] = first + ('a' - 'A')
	return string(buf)
}

// PrefixUpper returns s with its first byte upper-cased when that byte is an
// ASCII lowercase letter; any other input is returned unchanged.
func PrefixUpper(s string) string {
	if s == "" {
		return s
	}
	first := s[0]
	if first < 'a' || first > 'z' {
		return s
	}
	buf := []byte(s)
	buf[0] = first - ('a' - 'A')
	return string(buf)
}

// ConvertToString renders val as a string.
//
// Supported values are bool, string, every sized and unsized integer type,
// float32/float64, []byte, time.Time (formatted as "2006-01-02 15:04:05"),
// big.Float, pointers to any of these, and *any (which is unwrapped and
// converted again). Integers are rendered in base 10 and floats with the
// shortest representation that round-trips.
//
// nil yields "". A nil pointer of any supported pointer type also yields ""
// instead of panicking on the dereference. Any other type is logged through
// FuncLog and rendered with fmt's %v verb.
func ConvertToString(val any) string {
	// Pointer cases dereference into val and loop again, so every value type is
	// formatted in exactly one place.
	for {
		switch p := val.(type) {
		case nil:
			return ""
		case bool:
			return strconv.FormatBool(p)
		case string:
			return p
		case int8:
			return strconv.FormatInt(int64(p), 10)
		case int16:
			return strconv.FormatInt(int64(p), 10)
		case int:
			return strconv.FormatInt(int64(p), 10)
		case int32:
			return strconv.FormatInt(int64(p), 10)
		case int64:
			return strconv.FormatInt(p, 10)
		case uint8:
			return strconv.FormatUint(uint64(p), 10)
		case uint16:
			return strconv.FormatUint(uint64(p), 10)
		case uint:
			return strconv.FormatUint(uint64(p), 10)
		case uint32:
			return strconv.FormatUint(uint64(p), 10)
		case uint64:
			return strconv.FormatUint(p, 10)
		case []byte:
			return string(p)
		case float32:
			return strconv.FormatFloat(float64(p), 'f', -1, 32)
		case float64:
			return strconv.FormatFloat(p, 'f', -1, 64)
		case time.Time:
			return p.Format("2006-01-02 15:04:05")
		case big.Float:
			return p.String()
		case *bool:
			if p == nil {
				return ""
			}
			val = *p
		case *string:
			if p == nil {
				return ""
			}
			val = *p
		case *int8:
			if p == nil {
				return ""
			}
			val = *p
		case *int16:
			if p == nil {
				return ""
			}
			val = *p
		case *int:
			if p == nil {
				return ""
			}
			val = *p
		case *int32:
			if p == nil {
				return ""
			}
			val = *p
		case *int64:
			if p == nil {
				return ""
			}
			val = *p
		case *uint8:
			if p == nil {
				return ""
			}
			val = *p
		case *uint16:
			if p == nil {
				return ""
			}
			val = *p
		case *uint:
			if p == nil {
				return ""
			}
			val = *p
		case *uint32:
			if p == nil {
				return ""
			}
			val = *p
		case *uint64:
			if p == nil {
				return ""
			}
			val = *p
		case *[]byte:
			if p == nil {
				return ""
			}
			val = *p
		case *float32:
			if p == nil {
				return ""
			}
			val = *p
		case *float64:
			if p == nil {
				return ""
			}
			val = *p
		case *time.Time:
			if p == nil {
				return ""
			}
			return p.Format("2006-01-02 15:04:05")
		case *big.Float:
			if p == nil {
				return ""
			}
			return p.String()
		case *any:
			if p == nil {
				return ""
			}
			val = *p
		default:
			FuncLog("warn: unknown match params type")
			return fmt.Sprintf("%v", p)
		}
	}
}

// ConvertToStringSlice renders val as a slice of strings.
//
// Slices of string, of any integer type, of float32/float64, of time.Time and
// of big.Float are converted element by element, using the same formats as
// ConvertToString. Note that []uint8 (and therefore []byte) is treated as a
// list of numbers. An empty input slice yields a nil result.
//
// Any other value is converted with ConvertToString and returned as a
// one-element slice. []string used to take that path as well, which collapsed
// the whole slice into a single "[a b]" element; it is now copied through.
func ConvertToStringSlice(val any) []string {
	switch p := val.(type) {
	case []string:
		// Copy so the result does not alias the caller's slice.
		return append([]string(nil), p...)
	case []int:
		return mapToStrings(p, func(v int) string { return strconv.FormatInt(int64(v), 10) })
	case []int8:
		return mapToStrings(p, func(v int8) string { return strconv.FormatInt(int64(v), 10) })
	case []int16:
		return mapToStrings(p, func(v int16) string { return strconv.FormatInt(int64(v), 10) })
	case []int32:
		return mapToStrings(p, func(v int32) string { return strconv.FormatInt(int64(v), 10) })
	case []int64:
		return mapToStrings(p, func(v int64) string { return strconv.FormatInt(v, 10) })
	case []uint8:
		return mapToStrings(p, func(v uint8) string { return strconv.FormatUint(uint64(v), 10) })
	case []uint16:
		return mapToStrings(p, func(v uint16) string { return strconv.FormatUint(uint64(v), 10) })
	case []uint:
		return mapToStrings(p, func(v uint) string { return strconv.FormatUint(uint64(v), 10) })
	case []uint32:
		return mapToStrings(p, func(v uint32) string { return strconv.FormatUint(uint64(v), 10) })
	case []uint64:
		return mapToStrings(p, func(v uint64) string { return strconv.FormatUint(v, 10) })
	case []float32:
		return mapToStrings(p, func(v float32) string { return strconv.FormatFloat(float64(v), 'f', -1, 32) })
	case []float64:
		return mapToStrings(p, func(v float64) string { return strconv.FormatFloat(v, 'f', -1, 64) })
	case []time.Time:
		return mapToStrings(p, func(v time.Time) string { return v.Format("2006-01-02 15:04:05") })
	case []big.Float:
		// Formatted in place: String has a pointer receiver, so this avoids
		// copying each big.Float.
		if len(p) == 0 {
			return nil
		}
		out := make([]string, len(p))
		for i := range p {
			out[i] = p[i].String()
		}
		return out
	default:
		return []string{ConvertToString(p)}
	}
}

// mapToStrings applies format to every element of items and returns the
// results in order; it returns nil for an empty input.
func mapToStrings[T any](items []T, format func(T) string) []string {
	if len(items) == 0 {
		return nil
	}
	out := make([]string, len(items))
	for i := range items {
		out[i] = format(items[i])
	}
	return out
}

// GetPkColumnName returns the primary key column name of the configuration
// selected by ctx. getContextConf can return nil when no configuration is
// registered, in which case the field access below panics.
func GetPkColumnName(ctx context.Context) string {
	conf := getContextConf(ctx)
	return conf.PrimaryKeyColumnName
}

// genId returns an identifier that is intended to be unique: the current Unix
// time in nanoseconds followed by six zero-padded decimal digits drawn from
// crypto/rand. It is the default implementation behind FuncGenId.
//
// The random part covers the full range 000000-999999. If the random source
// fails, the digits are derived from the timestamp instead, so the result
// always consists of decimal digits only.
func genId() string {
	const suffixSpace = 1000000 // number of distinct six-digit suffixes

	timestamp := time.Now().UnixNano()
	suffix := timestamp % suffixSpace
	if randNumber, err := rand.Int(rand.Reader, big.NewInt(suffixSpace)); err == nil {
		suffix = randNumber.Int64()
	}
	return fmt.Sprintf("%d%06d", timestamp, suffix)
}
// defaultLog is the default FuncLog implementation. It writes its arguments,
// rendered as a bracketed list, to the standard logger wrapped in the ANSI
// colour sequence 36 (cyan).
//
// The arguments are formatted with %v so that non-string values such as ints
// or structs are readable; %s rendered them as "%!s(int=5)".
func defaultLog(arbitrarily ...any) {
	// Call depth 2 attributes the entry to the caller of defaultLog. A failed
	// log write is deliberately ignored: there is nowhere else to report it.
	_ = log.Output(2, fmt.Sprintf("\x1b[0;%dm%v\x1b[0m", 36, arbitrarily))
}

// defaultSqlLog is the default FuncSQLLog implementation. It writes the
// elapsed time ms in milliseconds, the statement sql and its params to the
// standard logger, wrapped in the ANSI colour sequence 36.
func defaultSqlLog(ms float64, sql string, params []any) {
	line := fmt.Sprintf("\x1b[0;%dm%f ms -> %s,%v\x1b[0m", 36, ms, sql, params)
	// Call depth 2 attributes the entry to the caller of defaultSqlLog.
	_ = log.Output(2, line)
}
